import itertools

import gym
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
from gym import spaces

# from DeliverymanGame.env_init.environment import deliveryManEnv
from env_dqn.environment_dqn import deliveryManEnv
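
# Note: tensorflow.contrib was removed in TensorFlow 2.x, so this script assumes
# a TensorFlow 1.x installation, matching what OpenAI Baselines expects.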

import baselines.common.tf_util as U

from baselines import logger
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer
from baselines.deepq.utils import ObservationInput
from baselines.common.schedules import LinearSchedule
import time


def model(inpt, num_actions, scope, reuse=False):
    """This model takes as input an observation and returns values of all actions."""
    with tf.variable_scope(scope, reuse=reuse):
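        # A minimal MLP Q-head: one 64-unit tanh hidden layer followed by a
        # linear output layer producing one Q-value per action.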
        out = inpt
        out = layers.fully_connected(out, num_outputs=64, activation_fn=tf.nn.tanh)
        out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
        return out
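

# A hedged sketch (not used below): a slightly deeper MLP Q-network, in case a
# single 64-unit hidden layer turns out to be too small for the flat 144-dim
# observation. The layer sizes and ReLU activations here are assumptions, not
# tuned values.
def deeper_model(inpt, num_actions, scope, reuse=False):
    """Two-hidden-layer MLP variant of the Q-network above."""
    with tf.variable_scope(scope, reuse=reuse):
        out = inpt
        out = layers.fully_connected(out, num_outputs=128, activation_fn=tf.nn.relu)
        out = layers.fully_connected(out, num_outputs=64, activation_fn=tf.nn.relu)
        out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
        return out
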

env_dm = deliveryManEnv()

# model = deepq.learn(
#     env_dm,
#     "conv_only",
#     convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
#     hiddens=[256],
#     dueling=True,
#     total_timesteps=0
# )
#


# Create the environment
# env_dm = gym.make("GridWorld-v0")
# Create all the functions necessary to train the model
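# The observation is assumed to be a flat length-144 vector (for example a
# flattened 12x12 grid); the high bound of 100000 is just a loose cap on values.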
obs_space = spaces.Box(low=0, high=100000, shape=(144, ))
act, train, update_target, debug = deepq.build_train(
    make_obs_ph=lambda name: ObservationInput(obs_space, name=name),
    # make_obs_ph=lambda name: ObservationInput(env_dm.observation_space, name=name),
    q_func=model,
    num_actions=5,
    optimizer=tf.train.AdamOptimizer(learning_rate=5e-4),
)
# Create the replay buffer
replay_buffer = ReplayBuffer(50000)
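# The buffer keeps the most recent 50000 transitions, overwriting the oldest.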
# Create the schedule for exploration starting from 1 (every action is random) down to
# 0.02 (98% of actions are selected according to values predicted by the model).
exploration = LinearSchedule(schedule_timesteps=100000, initial_p=1.0, final_p=0.02)

# Initialize the parameters and copy them to the target network.
U.initialize()
update_target()

episode_rewards = [0.0]
obs = env_dm.reset()
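# Map the integer action index produced by the Q-network onto the single-character
# commands the environment expects (left, right, up, down, plus 'o', whose meaning
# is environment-specific).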
action_set = ['l', 'r', 'u', 'd', 'o']
for t in itertools.count():
    # Take action and update exploration to the newest value
    action = act(obs[None], update_eps=exploration.value(t))[0]
    env_action = action_set[action]

    new_obs, rew, done = env_dm.step(env_action)
    # Store the transition in the replay buffer, keeping the integer action
    # index so that train() later receives the encoding the Q-network expects.
    replay_buffer.add(obs, action, rew, new_obs, float(done))
    obs = new_obs

    episode_rewards[-1] += rew
    if done:
        obs = env_dm.reset()
        episode_rewards.append(0)

    is_solved = t > 100 and np.mean(episode_rewards[-101:-1]) >= 200
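    # 200 is a CartPole-style success threshold; whether it is the right target
    # for this delivery environment is an assumption and may need tuning.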
    if is_solved:
        # Show off the result
        env_dm.render()
    else:
        # Minimize the error in Bellman's equation on a batch sampled from replay buffer.
        if t > 1000:
            obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(32)
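            # The last argument to train() is the per-sample importance weights;
            # with uniform replay they are all ones.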
            train(obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards))
        # Update target network periodically.
        if t % 1000 == 0:
            update_target()

    if done and len(episode_rewards) % 10 == 0:
        logger.record_tabular("steps", t)
        logger.record_tabular("episodes", len(episode_rewards))
        logger.record_tabular("mean episode reward", round(np.mean(episode_rewards[-101:-1]), 1))
        logger.record_tabular("% time spent exploring", int(100 * exploration.value(t)))
        logger.dump_tabular()
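    # Rendering and sleeping one second on every step slows training drastically;
    # presumably this is here for visual debugging.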
    time.sleep(1)
    env_dm.render()
